{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|default_exp losses"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Losses"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "> Losses not available in fastai or PyTorch."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "from tsai.imports import *\n",
    "from fastai.losses import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "# Natively available as nn.HuberLoss since PyTorch 1.9\n",
    "class HuberLoss(nn.Module):\n",
    "    \"\"\"Huber loss \n",
    "    \n",
    "    Creates a criterion that uses a squared term if the absolute\n",
    "    element-wise error falls below delta and a delta-scaled L1 term otherwise.\n",
    "    This loss combines advantages of both :class:`L1Loss` and :class:`MSELoss`; the\n",
    "    delta-scaled L1 region makes the loss less sensitive to outliers than :class:`MSELoss`,\n",
    "    while the L2 region provides smoothness over :class:`L1Loss` near 0. See\n",
    "    `Huber loss <https://en.wikipedia.org/wiki/Huber_loss>`_ for more information.\n",
    "    This loss is equivalent to nn.SmoothL1Loss when delta == 1.\n",
    "    \"\"\"\n",
    "    def __init__(self, reduction='mean', delta=1.0):\n",
    "        assert reduction in ['mean', 'sum', 'none'], \"You must set reduction to 'mean', 'sum' or 'none'\"\n",
    "        self.reduction, self.delta = reduction, delta\n",
    "        super().__init__()\n",
    "\n",
    "    def forward(self, input: Tensor, target: Tensor) -> Tensor:\n",
    "        diff = input - target\n",
    "        abs_diff = torch.abs(diff)\n",
    "        mask = abs_diff < self.delta\n",
    "        loss = torch.cat([(.5*diff[mask]**2), self.delta * (abs_diff[~mask] - (.5 * self.delta))])\n",
    "        if self.reduction == 'mean':\n",
    "            return loss.mean()\n",
    "        elif self.reduction == 'sum':\n",
    "            return loss.sum()\n",
    "        else: \n",
    "            return loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class LogCoshLoss(nn.Module):\n",
    "    def __init__(self, reduction='mean', delta=1.0):\n",
    "        assert reduction in ['mean', 'sum', 'none'], \"You must set reduction to 'mean', 'sum' or 'none'\"\n",
    "        self.reduction, self.delta = reduction, delta\n",
    "        super().__init__()\n",
    "        \n",
    "    def forward(self, input: Tensor, target: Tensor) -> Tensor:\n",
    "        loss = torch.log(torch.cosh(input - target + 1e-12))\n",
    "        if self.reduction == 'mean':\n",
    "            return loss.mean()\n",
    "        elif self.reduction == 'sum':\n",
    "            return loss.sum()\n",
    "        else: \n",
    "            return loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(0.4588)"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# HuberLoss with delta=1 is checked against nn.SmoothL1Loss on random data;\n",
    "# the value displayed by the cell is the LogCoshLoss of the same tensors\n",
    "inp = torch.rand(8, 3, 10)\n",
    "targ = torch.randn(8, 3, 10)\n",
    "test_close(HuberLoss(delta=1)(inp, targ), nn.SmoothL1Loss()(inp, targ))\n",
    "LogCoshLoss()(inp, targ)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class MaskedLossWrapper(Module):\n",
    "    \"Wraps the loss `crit` so that it is only evaluated where the target is not NaN\"\n",
    "    def __init__(self, crit):\n",
    "        self.loss = crit\n",
    "\n",
    "    def forward(self, inp, targ):\n",
    "        # both tensors are flattened from dim 1 onwards so that one boolean mask indexes them alike\n",
    "        flat_inp, flat_targ = inp.flatten(1), targ.flatten(1)\n",
    "        valid = ~torch.isnan(flat_targ)\n",
    "        # boolean indexing yields 1-D tensors holding only the valid positions\n",
    "        return self.loss(flat_inp[valid], flat_targ[valid])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor(nan), tensor(1.0520))"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Targets above .8 are replaced by NaN: the plain L1Loss then evaluates to nan, whereas the\n",
    "# wrapped loss is computed only on the positions whose target is not NaN\n",
    "inp = torch.rand(8, 3, 10)\n",
    "targ = torch.randn(8, 3, 10)\n",
    "targ[targ >.8] = np.nan\n",
    "nn.L1Loss()(inp, targ), MaskedLossWrapper(nn.L1Loss())(inp, targ)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class CenterLoss(Module):\n",
    "    r\"\"\"\n",
    "    Center loss: average (over the batch) squared euclidean distance between each sample and the\n",
    "    learnable center of its class.\n",
    "\n",
    "    Code in Pytorch has been slightly modified from: https://github.com/KaiyangZhou/pytorch-center-loss/blob/master/center_loss.py\n",
    "    Based on paper: Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016.\n",
    "\n",
    "    Args:\n",
    "        c_out (int): number of classes.\n",
    "        logits_dim (int): dim 1 of the logits. By default same as c_out (for one hot encoded logits)\n",
    "\n",
    "    \"\"\"\n",
    "    def __init__(self, c_out, logits_dim=None):\n",
    "        logits_dim = ifnone(logits_dim, c_out)\n",
    "        self.c_out = c_out\n",
    "        self.logits_dim = logits_dim\n",
    "        # one randomly initialised, learnable center per class\n",
    "        self.centers = nn.Parameter(torch.randn(c_out, logits_dim))\n",
    "        # class indices 0..c_out-1, stored as a non-trainable parameter\n",
    "        self.classes = nn.Parameter(torch.arange(c_out).long(), requires_grad=False)\n",
    "\n",
    "    def forward(self, x, labels):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            x: feature matrix with shape (batch_size, logits_dim).\n",
    "            labels: ground truth labels with shape (batch_size).\n",
    "        \"\"\"\n",
    "        bs = x.shape[0]\n",
    "\n",
    "        # pairwise squared distances, expanded as ||x||^2 + ||c||^2 - 2 * x @ c.T\n",
    "        x_sq = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(bs, self.c_out)\n",
    "        centers_sq = torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.c_out, bs).T\n",
    "        distmat = torch.addmm(x_sq + centers_sq, x, self.centers.T, beta=1, alpha=-2)\n",
    "\n",
    "        # for every sample, keep only the distance to the center of its own class\n",
    "        labels = labels.unsqueeze(1).expand(bs, self.c_out)\n",
    "        own_class = labels.eq(self.classes.expand(bs, self.c_out))\n",
    "        masked_dist = distmat * own_class.float()\n",
    "\n",
    "        return masked_dist.clamp(min=1e-12, max=1e+12).sum() / bs\n",
    "\n",
    "\n",
    "class CenterPlusLoss(Module):\n",
    "    \"Adds `λ` times a `CenterLoss` term to the wrapped `loss`\"\n",
    "\n",
    "    def __init__(self, loss, c_out, λ=1e-2, logits_dim=None):\n",
    "        self.loss = loss\n",
    "        self.c_out = c_out\n",
    "        self.λ = λ\n",
    "        self.centerloss = CenterLoss(c_out, logits_dim)\n",
    "\n",
    "    def forward(self, x, labels):\n",
    "        base_loss = self.loss(x, labels)\n",
    "        center_term = self.centerloss(x, labels)\n",
    "        return base_loss + self.λ * center_term\n",
    "\n",
    "    def __repr__(self):\n",
    "        return f\"CenterPlusLoss(loss={self.loss}, c_out={self.c_out}, λ={self.λ})\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor(9.2481, grad_fn=<DivBackward0>),\n",
       " TensorBase(2.3559, grad_fn=<AliasBackward0>))"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Softmax-normalised random features; every row is labelled with the index of its largest value.\n",
    "# The cell displays the CenterLoss alone and combined with a label-smoothing cross entropy.\n",
    "c_in = 10\n",
    "x = torch.rand(64, c_in).to(device=default_device())\n",
    "x = F.softmax(x, dim=1)\n",
    "label = x.max(dim=1).indices\n",
    "CenterLoss(c_in).to(x.device)(x, label), CenterPlusLoss(LabelSmoothingCrossEntropyFlat(), c_in).to(x.device)(x, label)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "CenterPlusLoss(loss=FlattenedLoss of LabelSmoothingCrossEntropy(), c_out=10, λ=0.01)"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Displays the text produced by CenterPlusLoss.__repr__\n",
    "CenterPlusLoss(LabelSmoothingCrossEntropyFlat(), c_in)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class FocalLoss(Module):\n",
    "    \"\"\" Weighted, multiclass focal loss\n",
    "\n",
    "    The per-sample (optionally class-weighted) negative log-likelihood is scaled by\n",
    "    `(1 - pt) ** gamma`, where `pt` is the softmax probability assigned to the target class.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, alpha:Optional[Tensor]=None, gamma:float=2., reduction:str='mean'):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            alpha (Tensor, optional): Weights for each class (passed to nn.NLLLoss as `weight`). Defaults to None.\n",
    "            gamma (float, optional): A constant, as described in the paper. Defaults to 2.\n",
    "            reduction (str, optional): 'mean', 'sum' or 'none'. Defaults to 'mean'.\n",
    "        \"\"\"\n",
    "        # same check as the other losses in this module: a misspelt value used to behave like 'none'\n",
    "        assert reduction in ['mean', 'sum', 'none'], \"You must set reduction to 'mean', 'sum' or 'none'\"\n",
    "        self.alpha, self.gamma, self.reduction = alpha, gamma, reduction\n",
    "        self.nll_loss = nn.NLLLoss(weight=alpha, reduction='none')\n",
    "\n",
    "    def forward(self, x: Tensor, y: Tensor) -> Tensor:\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            x: raw scores; the softmax is taken over the last dimension.\n",
    "            y: target class indices, one per row of `x`.\n",
    "        \"\"\"\n",
    "        log_p = F.log_softmax(x, dim=-1)\n",
    "        # row indices are created on the device of `x` so the lookup needs no host-to-device copy\n",
    "        rows = torch.arange(len(x), device=x.device)\n",
    "        pt = log_p[rows, y].exp()\n",
    "        ce = self.nll_loss(log_p, y)\n",
    "        loss = (1 - pt) ** self.gamma * ce\n",
    "\n",
    "        if self.reduction == 'mean':\n",
    "            loss = loss.mean()\n",
    "        elif self.reduction == 'sum':\n",
    "            loss = loss.sum()\n",
    "        return loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(0.9829)"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Random scores for 16 samples and 2 classes, with random integer targets in {0, 1}\n",
    "inputs = torch.normal(0, 2, (16, 2)).to(device=default_device())\n",
    "targets = torch.randint(0, 2, (16,)).to(device=default_device())\n",
    "FocalLoss()(inputs, targets)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class TweedieLoss(Module):\n",
    "    \"Tweedie loss as calculated in LightGBM\"\n",
    "    def __init__(self, p=1.5, eps=1e-8):\n",
    "        \"\"\"\n",
    "        Tweedie loss as calculated in LightGBM\n",
    "        Args:\n",
    "            p: tweedie variance power (1 < p < 2)\n",
    "            eps: small number to avoid log(zero).\n",
    "        \"\"\"\n",
    "        assert 1 < p < 2, \"make sure 1 < p < 2\"\n",
    "        self.p, self.eps = p, eps\n",
    "\n",
    "    def forward(self, inp, targ):\n",
    "        \"Poisson and compound Poisson distribution, targ >= 0, inp > 0\"\n",
    "        # Out-of-place clamp: `flatten` may return a view of `inp`, so the previous in-place clamp\n",
    "        # overwrote the caller's predictions (and is not allowed on a leaf tensor that requires grad).\n",
    "        inp = torch.clamp(inp.flatten(), min=self.eps)\n",
    "        targ = targ.flatten()\n",
    "        log_inp = torch.log(inp)\n",
    "        # exp(k * log(inp)) is inp ** k\n",
    "        a = targ * torch.exp((1 - self.p) * log_inp) / (1 - self.p)\n",
    "        b = torch.exp((2 - self.p) * log_inp) / (2 - self.p)\n",
    "        loss = -a + b\n",
    "        return loss.mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(3.0539)"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Random predictions and targets (64 values each) evaluated with the default TweedieLoss (p=1.5)\n",
    "c_in = 10\n",
    "output = torch.rand(64).to(device=default_device())\n",
    "target = torch.rand(64).to(device=default_device())\n",
    "TweedieLoss().to(output.device)(output, target)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/javascript": "IPython.notebook.save_checkpoint();",
      "text/plain": [
       "<IPython.core.display.Javascript object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/Users/nacho/notebooks/tsai/nbs/050_losses.ipynb saved at 2022-11-09 12:52:08\n",
      "Correct notebook to script conversion! 😃\n",
      "Wednesday 09/11/22 12:52:11 CET\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "                <audio  controls=\"controls\" autoplay=\"autoplay\">\n",
       "                    <source src=\"data:audio/wav;base64,UklGRvQHAABXQVZFZm10IBAAAAABAAEAECcAACBOAAACABAAZGF0YdAHAAAAAPF/iPh/gOoOon6w6ayCoR2ZeyfbjobxK+F2Hs0XjKc5i3DGvzaTlEaraE+zz5uLUl9f46fHpWJdxVSrnfmw8mYEScqUP70cb0Q8X41uysJ1si6Eh1jYzXp9IE2DzOYsftYRyoCY9dJ/8QICgIcEun8D9PmAaBPlfT7lq4MFIlh61tYPiCswIHX+yBaOqT1QbuW7qpVQSv9lu6+xnvRVSlyopAypbGBTUdSalrSTaUBFYpInwUpxOzhti5TOdndyKhCGrdwAfBUcXIJB69p+Vw1egB76+n9q/h6ADglbf4LvnIHfF/981ODThF4m8HiS0riJVjQ6c+/EOZCYQfJrGrhBmPVNMmNArLKhQlkXWYqhbaxXY8ZNHphLuBJsZUEckCTFVHMgNKGJytIDeSUmw4QN4Qx9pReTgb3vYX/TCBuApf75f+P5Y4CRDdN+B+tngk8c8nt03CKGqipgd13OhotwOC5x9MCAknFFcmlmtPmagFFFYOCo0qRzXMhVi57pryNmIEqJlRi8bm52PfuNM8k4dfQv+4cO12l6zCGdg3jl730uE/KAPvS+f0wEAoAsA89/XfXQgBESIn6S5luDtiC8eh/YmIfpLqt1OMp5jXg8/24MveqUNUnPZsqw0Z3yVDldnaUOqIZfXlKrm36zzWhjRhaT+r+ncHI5/otUzfd2uSt7hl/bqXtoHaCC6+mqfrAOeoDD+PJ/xf8RgLMHfH/b8GeBihZIfSXidoQSJWB52NM1iRkzz3MkxpKPbUCrbDu5d5fgTAxkSK3JoEhYD1p2omere2LZTuqYLbdWa49Cx5Dww7tyXDUnioXRkHhwJyKFvd/AfPoYy4Fl7j1/LQorgEr9/X89+0qAOAwAf13sJoL8Gkd8wt25hWIp3Heez/eKODfPcSPCzpFNRDVqf7UlmnNQKGHgqd+jgVvJVm2f265QZTpLS5byur1tpT6ajvrHq3Q2MXWIxtUCehoj8YMk5LB9hRQegeTypn+nBQWA0QHgf7f2q4C5EFt+5ucOg2YfHXtq2SSHpS0ydnTL4IxFO6pvNb4ulBdInWfcsfSc7VMmXpSmE6eeXmZThJxpsgRohEfOk86+AHCoOpOMFsx1dv8s6oYT2k17uR7ngpXod34IEJqAaPfnfyABCIBZBpl/NPI2gTQVjX134x2ExSPMeR7VtYjZMWJ0W8ftjkA/YW1durCWykvjZFKu4p9LVwVbZKNkqpxh6U+6mRC2mGq2Q3SRvsIgcpc2sIpD0Bp4uiiFhW3ecXxOGgaCDe0Vf4cLPoDv+/5/mfw1gN4KKX+17emBqBmYfBHfVYUZKFR44NBtiv41bHJUwx+RJkP1apu2VJlkTwli4qrwoo1ax1dToNCtemRSTBGXz7kJbdM/PY/Dxht0dTLziH7Ul3loJEiE0uJsfdsVTYGL8Yt/AgcMgHYA7X8S+IqAYA+QfjzpxIIVHnp7tdqzhmAstXaxzEqMETpScGC/dJP3Rmdo8LIZnOVSEF+Opxumsl1sVF+dVrE5Z6NIiZSkvVdv2zsqjdnK8HVDLlyHyNjuegogM4NA5z9+YRG9gA722H97AgOA/gSyf43zCIHdE899yuTIg3ciNXpm1jmImTDwdJPITI4RPhRugbvslbFKt2Vfr/6eTFb4W1WkY6m6YPdQjJr2tNZp3EQlko7BgXHRNz2LAc+gdwMq7IUf3R58ohtFgrbr6n7hDFWAlPr8f/T9I4CECU9/De+vgVQY5nxh4POEzybJeCTS5YnCNAZzhsRzkP1Bsmu4t4aYU07nYuerA6KWWcJYO6HHrKJjaE3Zl624UWz/QOOPjcWHc7QzdIk40yl5tCWjhIDhJX0xF4CBMvBsf10IF4Ac//Z/bPlsgAcO
wn6S6n6CwxzUewLcRoYaKzV38M23i9o493CNwL6S1UUuaQe0QpvbUfdfiqglpcRccFU+nkWwambASUiVfLyqbg49xY2eyWh1hy/Sh37XjHpaIYKD7OUEfrgS5IC09MV/1gMBgKMDyH/n9N6AhhINfh7mdoMoIZt6r9fAh1cvfHXNya6N4DzDbqi8K5WWSYlmbbAdnkpV6FxJpWSo1V8DUmGb3rMRaQBG2JJgwN9wCDnNi8HNI3dKK1aG0dvHe/UciIJf6rt+Og5wgDn59X9P/xWAKQhxf2XweYH+FjB9suGVhIMlOnlo02GJhTOdc7vFyo/TQGxs2Li7lz9NwmPurBihnVi7WSWiwKvGYntOpJiOt5drKUKMkFnE8HLxNPmJ9NG4eP8mAYUv4Np8hhi3gdruSX+3CSWAwP38f8f6UoCuDPF+6Os8gnAbKnxQ3d2F0imydzDPKIuiN5lxu8EKkrFE82kftW2az1DbYImpMqTUW3FWIJ83r5hl2koJlla7+m0+PmSOZcjcdMgwS4g11iZ6qCLUg5jkxn0QFA6BWvOvfzEFBIBHAtp/Qfa3gC4RSH5y5yeD2B/8evnYS4cULgR2CMsUja47cG/QvW6UeEhXZ3+xP51GVNVdP6Zpp+1eDFM5nMeySWghR4+TNL85cD46YIyCzKJ2kCzEhoTabXtGHs+CCemJfpMPjoDe9+t/qQALgM8Gj3++8UaBqRV2fQTjO4Q3JKd5r9TgiEYyMHTxxiWPpz8jbfq585YpTJpk960xoKFXsVoTo7yq6GGMTw==\" type=\"audio/wav\" />\n",
       "                    Your browser does not support the audio element.\n",
       "                </audio>\n",
       "              "
      ],
      "text/plain": [
       "<IPython.lib.display.Audio object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "#|eval: false\n",
    "#|hide\n",
    "# NOTE(review): presumably saves this notebook and converts its exported cells into the library\n",
    "# script (see the output stored with this cell) - confirm against tsai.export / tsai.imports\n",
    "from tsai.export import get_nb_name; nb_name = get_nb_name(locals())\n",
    "from tsai.imports import create_scripts; create_scripts(nb_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "python3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
